import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import pandas as pd
from pathos.multiprocessing import Pool, cpu_count
import seaborn as sns
import scipy
import pathos
from subprocess import Popen, PIPE
import getpass
import emcee
from functools import partial

from tqdm.notebook import tqdm
from operator import add
from functools import reduce
from contextlib import closing

from jm.library.data_helper import filter_data, date_to_str, is_between, cols_by_date,     is_weakly_less_than, is_greater_than
from jm.library.likelihood import Likelihood, Constraint
from jm.library.simulator import TargetPrioritiesTargetActions, CounterfactualConfigs
from jm.library.additional_helpers import plot_over_time, payment_cols, priority_cols, action_cols, \
    relative_payment_cols, fwdiff_payment_cols, relative_fwdiff_payment_cols, \
    action_nomedida_cols, cumul_binary_fwdiff_payment_cols, binary_fwdiff_payment_cols
import warnings
warnings.filterwarnings("ignore")


# Load the de-identified status panel that every later step works from.
df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')

# Rescale the payment columns, and their forward differences, by each row's
# total_due (row-wise division aligned on the index).
df_status[relative_payment_cols] = df_status[payment_cols].div(
    df_status['total_due'], axis='index')
df_status[relative_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].div(
    df_status['total_due'], axis='index')

# 99th percentile of total_due over all rows of df_status.
td_pctile99 = df_status['total_due'].quantile(0.99)

# Edges of the total_due brackets; consecutive pairs form one bracket each.
list_tax_due = [
    99, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
    1250, 1500, 2000, 60000,
]


# In[ ]:


# For every total_due bracket, collect the strictly positive relative
# forward-difference payments observed in it, keyed by the bracket's lower
# edge. NaN entries fail the `> 0` test and are dropped as well.
map_payments = {}

for lower, upper in zip(list_tax_due[:-1], list_tax_due[1:]):
    in_bracket = filter_data(df_status, total_due=is_between(lower, upper))
    shares = in_bracket[relative_fwdiff_payment_cols].to_numpy().flatten()
    map_payments[lower] = shares[shares > 0]


# In[ ]:


# Rows of df_status with assignment_to_treatment == 1.
df_treatment = filter_data(df_status, assignment_to_treatment=1)


# In[ ]:


# Labels for the estimated parameter vector, in the order the MCMC chain
# stores them. Later code addresses parameters by position (slices 2:5 and
# 9:11, and the last entry), so the order here must not change.
PARAM_NAMES = (
    'some_payment',
    'payment',
    'priority_G1',
    'priority_G2',
    'priority_G3',
    'action_valor',
    'action_rec1',
    'action_medida',
    'covariate',
    'G1_rec1',
    'G1_medida',
    's_shape_low',
    's_shape_high',
    'type_sdev',
)


# In[ ]:


# Load the stored MCMC run and reduce it to posterior-mean parameters.
backend = emcee.backends.HDFBackend(
    'estimation/parameter_estimates/jesus_maria_mcmc_G1actioninteraction.h5')
state = backend.get_last_sample()

# Number of leading steps to drop and the thinning stride. When the
# test_size environment variable is "true", no steps are dropped.
ndiscard = 3000
if os.environ.get("test_size") == "true":
    ndiscard = 0
nthin = 1

# Unflattened chain, indexed (step, walker, parameter); drop the leading
# steps and thin along the step axis.
chain_array = backend.get_chain(flat=False)[ndiscard::nthin, :, :]
if chain_array.shape[0] == 0:
    # Fail loudly instead of averaging an empty chain into NaN parameters.
    raise ValueError(
        'No MCMC samples left after discarding {} steps.'.format(ndiscard))

# Size the flattened chain from what was actually kept, so the result is
# correct for any walker count and any thinning stride.
flat_len = chain_array.shape[0] * chain_array.shape[1]
chain_array_flat = chain_array.reshape(flat_len, chain_array.shape[2])

df_params = pd.DataFrame(chain_array_flat, columns=PARAM_NAMES)
# Column-wise mean over all retained samples, as a fresh ndarray.
params = np.array(list(df_params.mean()))

# Last entry of the parameter vector ('type_sdev' in PARAM_NAMES).
sigma_theta = params[-1]

# Number of treated rows; used as the simulator size and the length of the
# per-simulation theta draw.
n = len(df_treatment)

# Number of simulation replications (a single one when test_size is "true").
num_sim = 1 if os.environ.get("test_size") == "true" else 400

configs = CounterfactualConfigs()
initial_G1 = configs.initial_G1


# Stable descending sort, so ties keep their previous relative order.
df_treatment = df_treatment.sort_values(
    by='effective_score', ascending=False, kind='mergesort')

# Columns that make up the initial state handed to the simulator. Their
# order matters: the simulation loop addresses them by position.
state_0_selected_cols = [
    'total_due',
    payment_cols[0],
    priority_cols[0],
    action_cols[0],
    'prob_repayment_endo_covariates',
]


# Output column names: one (payments, priority, action) triple per event
# date, flattened into a single list in date order.
event_date_labels = [event_date.strftime('%Y-%m-%d')
                     for event_date in Likelihood.event_dates]
cols = [['payments_by_' + label, 'priority_by_' + label, 'action_by_' + label]
        for label in event_date_labels]

cols = reduce(add, cols)

def run_simulator(sim_seed, initial_G1=initial_G1):
    """Run one seeded simulation and return its full trajectory.

    Parameters
    ----------
    sim_seed : tuple
        ``(sim, seed)``. ``sim`` must provide ``reset()`` and be callable as
        ``sim(state, period)``; ``seed`` seeds NumPy's global RNG.
    initial_G1 : int
        Number of leading rows of ``df_treatment`` (in its current order)
        that start in priority "G1"; all other rows start in "G3". The
        default is bound when this function is defined, so reassigning the
        module-level ``initial_G1`` afterwards does not change it.

    Returns
    -------
    pandas.DataFrame
        One row per unit with columns ``cols``: columns 1-3 of the state
        array, recorded for the initial state and after each of the 21
        simulated periods.

    Notes
    -----
    Reads the module-level ``df_treatment``, ``state_0_selected_cols``,
    ``priority_cols``, ``sigma_theta``, ``n`` and ``cols`` at call time.
    """
    simulator, seed = sim_seed
    simulator.reset()
    np.random.seed(seed)

    first_priority_col = priority_cols[0]
    initial_state = df_treatment.loc[:, state_0_selected_cols]
    # Everyone starts in G3, then the first `initial_G1` rows are moved to G1.
    initial_state[first_priority_col] = "G3"
    g1_labels = initial_state[0:initial_G1].index
    initial_state.loc[g1_labels, first_priority_col] = "G1"
    # Unit-level random draw scaled by sigma_theta.
    initial_state['theta'] = sigma_theta * np.random.randn(n)

    # Positions 1, 2, 3 of the state array are the columns that get recorded.
    tracked = [1, 2, 3]
    current = np.array(initial_state.values)
    snapshots = [current[:, tracked]]
    for period in range(21):
        current = np.array(simulator(current, period))
        snapshots.append(current[:, tracked])

    trajectory = np.concatenate(snapshots, axis=1)
    return pd.DataFrame(trajectory, columns=cols)


# In[ ]:


def seeded_simulator(sim, base_seed, num_sim):
    """Pair one simulator with ``num_sim`` consecutive seeds.

    Parameters
    ----------
    sim : object
        Simulator shared by every pair (the same object is reused, not copied).
    base_seed : int
        Seed of the first pair; later pairs use ``base_seed + 1``, and so on.
    num_sim : int
        Number of ``(sim, seed)`` pairs to produce.

    Returns
    -------
    list of tuple
        ``[(sim, base_seed), ..., (sim, base_seed + num_sim - 1)]``.
    """
    # The list length is driven by `num_sim` alone, so the result never
    # depends on the size of any module-level dataset.
    return [(sim, seed) for seed in range(base_seed, base_seed + num_sim)]


# In[ ]:


# 'base sim - actual'
# NOTE(review): this assignment does not affect the run below. run_simulator
# bound its `initial_G1` default when it was defined, and pool.map calls it
# without overriding that argument. Confirm whether 350 was meant to be
# passed explicitly.
initial_G1 = 350
base_config = configs.longer_G1_more_G1_actual
np.random.seed(1234)
base_simulator = TargetPrioritiesTargetActions(
    params, map_payments, n=n, config=base_config, list_tax_due=list_tax_due, timetrend=False)

# Leave two cores free, but always ask for at least one worker: a worker
# count below 1 makes Pool raise ValueError on machines with <= 2 cores.
with closing(Pool(max(1, cpu_count() - 2))) as pool:
    # One simulation per seed, starting at 1234.
    base_res = pool.map(run_simulator, seeded_simulator(base_simulator, 1234, num_sim))


# In[ ]:


# Turn each simulated panel into payment-event indicators and running event
# counts, then attach total_due.
fwdiff_names = dict(zip(payment_cols, binary_fwdiff_payment_cols))
# total_due in df_treatment's current row order, on a fresh 0..n-1 index so
# it lines up with the simulation output rows.
treated_total_due = df_treatment['total_due'].reset_index()

for idx, sim_df in enumerate(base_res):
    sim_df.loc[:, payment_cols] = sim_df.loc[:, payment_cols].fillna(0)

    # Change from each date to the next; the last date has no successor and
    # is NaN, which the `> 0` comparison below turns into False.
    forward_change = sim_df[payment_cols].diff(axis=1).shift(-1, axis=1)
    sim_df[binary_fwdiff_payment_cols] = forward_change.rename(columns=fwdiff_names)
    sim_df.loc[:, binary_fwdiff_payment_cols] = (
        sim_df.loc[:, binary_fwdiff_payment_cols] > 0)

    # Running count of events up to and including each date.
    for upto in range(1, len(binary_fwdiff_payment_cols) + 1):
        running_col = cumul_binary_fwdiff_payment_cols[upto - 1]
        sim_df[running_col] = sim_df[binary_fwdiff_payment_cols[:upto]].sum(axis=1)

    base_res[idx] = sim_df.merge(treated_total_due,
                                 left_index=True, right_index=True)

# One row per simulation: running event counts summed over all units.
payments_base_res = [sim_df[cumul_binary_fwdiff_payment_cols].sum(axis=0)
                     for sim_df in base_res]
df_base_res = pd.DataFrame(payments_base_res)
df_base_res.columns = Likelihood.event_dates


# In[ ]:


# 'control sim'
# Re-order the treated units by total_due (stable, descending). run_simulator
# reads df_treatment at call time, so the control runs use this order.
df_treatment = df_treatment.sort_values(
    'total_due', ascending=False, kind='mergesort')
np.random.seed(1234)

# With initial_G1=0 no unit is moved to "G1"; every unit starts in "G3".
run_simulator_control = partial(run_simulator, initial_G1=0)

control_config = configs.control_config

# Copy the estimates and zero positions 2:5 (priority_G1..priority_G3 in
# PARAM_NAMES) and 9:11 (G1_rec1, G1_medida).
params_control = params.copy()
params_control[2:5] = 3*[0]
params_control[9:11] = 2*[0]

control_simulator = TargetPrioritiesTargetActions(
    params_control, map_payments, n=n, config=control_config, list_tax_due=list_tax_due, timetrend=False)

np.random.seed(1234)
# Leave two cores free, but always ask for at least one worker: a worker
# count below 1 makes Pool raise ValueError on machines with <= 2 cores.
with closing(Pool(max(1, cpu_count() - 2))) as pool:
    control_res = pool.map(run_simulator_control, seeded_simulator(control_simulator, 1234, num_sim))



# Turn each control panel into payment-event indicators and running event
# counts, using the same definitions as the treatment results.
for j in range(0,len(control_res)):
    control_res[j].loc[:, payment_cols] = control_res[j].loc[:, payment_cols].fillna(0)
    # Change from each date to the next; the last date has no successor and
    # is NaN, which the `> 0` comparison below turns into False.
    control_res[j][binary_fwdiff_payment_cols] = control_res[j][
        payment_cols].diff(axis=1).shift(-1, axis=1).rename(
            columns=dict(zip(payment_cols, binary_fwdiff_payment_cols)))
    control_res[j].loc[:, binary_fwdiff_payment_cols] = control_res[j].loc[:, binary_fwdiff_payment_cols] > 0
    for k,col in enumerate(binary_fwdiff_payment_cols):
        # Running count up to and INCLUDING date k (slice 0:k+1), matching the
        # treatment post-processing. The final column is the same either way
        # because the last forward difference is always False.
        control_res[j][cumul_binary_fwdiff_payment_cols[k]] = control_res[j][binary_fwdiff_payment_cols[0:k+1]].sum(axis=1)

# Attach total_due in df_treatment's current row order, on a fresh 0..n-1
# index so it lines up with the simulation output rows.
for j in range(0,len(control_res)):
    control_res[j] = control_res[j].merge(df_treatment['total_due'].reset_index(),
                                          left_index=True, right_index=True)
# One row per simulation: running event counts summed over all units.
payments_control_res = [df[cumul_binary_fwdiff_payment_cols].sum(axis=0) for df in control_res]
df_control_res = pd.DataFrame(payments_control_res)
df_control_res.columns = Likelihood.event_dates


# ### Average payment per event: control vs. treatment

# Columns the per-event ratios are computed from.
final_payment_col = 'payments_by_2021-09-06'
final_event_col = 'cumul_binary_fwdiff_payments_by_2021-09-06'


def _payment_per_event(sim_df):
    """Return the summed payment column divided by the summed event-count column."""
    return sim_df[final_payment_col].sum() / sim_df[final_event_col].sum()


def _up_to_p99(sim_df):
    """Keep the rows whose total_due does not exceed td_pctile99."""
    return sim_df[sim_df['total_due'] <= td_pctile99]


def _mean_over_sims(per_sim_values):
    """Average a list of per-simulation values."""
    return pd.DataFrame(per_sim_values).values.mean()


# One ratio per simulation, over all units.
base_res_values = [_payment_per_event(sim_df) for sim_df in base_res]
control_res_values = [_payment_per_event(sim_df) for sim_df in control_res]

# #### total_due at or below the 99th percentile

base_res_p99_values = [_payment_per_event(_up_to_p99(sim_df))
                       for sim_df in base_res]
control_res_p99_values = [_payment_per_event(_up_to_p99(sim_df))
                          for sim_df in control_res]

column_names = [
    'treatment_num_event',
    'control_num_event',
    'treatment_average_payment_per_event',
    'control_average_payment_per_event',
    'treatment_average_payment_per_event_p99',
    'control_average_payment_per_event_p99',
]

# Event counts use the last column of the per-simulation count tables,
# averaged across simulations; the ratios are averaged across simulations.
values = [
    df_base_res.mean().iloc[-1],
    df_control_res.mean().iloc[-1],
    _mean_over_sims(base_res_values),
    _mean_over_sims(control_res_values),
    _mean_over_sims(base_res_p99_values),
    _mean_over_sims(control_res_p99_values),
]

# Single-row summary table.
df_out = pd.DataFrame([values], columns=column_names)
df_out.to_csv('estimation/simulation_estimates/simulations_binaryevents_atmean.csv')
